# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for create_pascal_tfrecord.py."""

import os

from absl import logging
import numpy as np
import PIL.Image
import six
import tensorflow as tf

from dataset import create_pascal_tfrecord


class CreatePascalTFRecordTest(tf.test.TestCase):
  """Tests for create_pascal_tfrecord.dict_to_tf_example."""

  def _assertProtoEqual(self, proto_field, expectation):
    """Helper function to assert if a proto field equals some value.

    Args:
      proto_field: The (iterable) protobuf field to compare.
      expectation: A list holding the expected values of the protobuf field.
    """
    # assertListEqual requires real lists, so materialize the proto field.
    self.assertListEqual(list(proto_field), expectation)

  def test_dict_to_tf_example(self):
    """Converts one two-object annotation dict and checks every output.

    Writes a random 256x256 JPEG into the test temp dir, runs
    dict_to_tf_example on a matching annotation dict, and verifies the id
    counters, the features of the returned example and the entries appended
    to the annotation JSON dict.
    """
    image_file_name = '2012_12.jpg'
    # Generate genuine 8-bit pixels. np.random.rand() returns float64 values
    # in [0, 1), which are not a valid buffer for an 8-bit RGB image.
    image_data = np.random.randint(
        0, 256, size=(256, 256, 3), dtype=np.uint8)
    save_path = os.path.join(self.get_temp_dir(), image_file_name)
    # A (height, width, 3) uint8 array is read by PIL as an RGB image, so no
    # explicit mode argument is needed.
    image = PIL.Image.fromarray(image_data)
    image.save(save_path)

    data = {
        'folder':
            '',
        'filename':
            image_file_name,
        'size': {
            'height': 256,
            'width': 256,
        },
        'object': [
            {
                'difficult': 1,
                'bndbox': {
                    'xmin': 64,
                    'ymin': 64,
                    'xmax': 192,
                    'ymax': 192,
                },
                'name': 'person',
                'truncated': 0,
                'pose': '',
            },
            {
                'difficult': 0,
                'bndbox': {
                    'xmin': 128,
                    'ymin': 128,
                    'xmax': 256,
                    'ymax': 256,
                },
                'name': 'notperson',
                'truncated': 0,
                'pose': '',
            },
        ],
    }

    label_map_dict = {
        'background': 0,
        'person': 1,
        'notperson': 2,
    }

    ann_json_dict = {'images': [], 'annotations': [], 'categories': []}
    unique_id = create_pascal_tfrecord.UniqueId()
    example = create_pascal_tfrecord.dict_to_tf_example(
        data,
        self.get_temp_dir(),
        label_map_dict,
        unique_id,
        ann_json_dict=ann_json_dict)
    # One image with two objects was processed.
    self.assertEqual(unique_id.image_id, 1)
    self.assertEqual(unique_id.ann_id, 2)

    features = example.features.feature

    # Image-level features.
    self._assertProtoEqual(features['image/height'].int64_list.value, [256])
    self._assertProtoEqual(features['image/width'].int64_list.value, [256])
    self._assertProtoEqual(
        features['image/filename'].bytes_list.value,
        [six.b(image_file_name)])
    self._assertProtoEqual(
        features['image/source_id'].bytes_list.value, [six.b(str(1))])
    self._assertProtoEqual(
        features['image/format'].bytes_list.value, [six.b('jpeg')])

    # Box coordinates are expected as fractions of the 256-pixel image size;
    # all of these values are exactly representable as floats.
    self._assertProtoEqual(
        features['image/object/bbox/xmin'].float_list.value, [0.25, 0.5])
    self._assertProtoEqual(
        features['image/object/bbox/ymin'].float_list.value, [0.25, 0.5])
    self._assertProtoEqual(
        features['image/object/bbox/xmax'].float_list.value, [0.75, 1.0])
    self._assertProtoEqual(
        features['image/object/bbox/ymax'].float_list.value, [0.75, 1.0])

    # Per-object class and attribute features, in annotation order.
    self._assertProtoEqual(
        features['image/object/class/text'].bytes_list.value,
        [six.b('person'), six.b('notperson')])
    self._assertProtoEqual(
        features['image/object/class/label'].int64_list.value, [1, 2])
    self._assertProtoEqual(
        features['image/object/difficult'].int64_list.value, [1, 0])
    self._assertProtoEqual(
        features['image/object/truncated'].int64_list.value, [0, 0])
    self._assertProtoEqual(
        features['image/object/view'].bytes_list.value,
        [six.b(''), six.b('')])

    # The JSON annotations are expected to hold boxes as
    # [xmin, ymin, width, height] in pixels.
    expected_ann_json_dict = {
        'annotations': [{
            'area': 16384,
            'iscrowd': 0,
            'image_id': 1,
            'bbox': [64, 64, 128, 128],
            'category_id': 1,
            'id': 1,
            'ignore': 0,
            'segmentation': []
        }, {
            'area': 16384,
            'iscrowd': 0,
            'image_id': 1,
            'bbox': [128, 128, 128, 128],
            'category_id': 2,
            'id': 2,
            'ignore': 0,
            'segmentation': []
        }],
        'categories': [],
        'images': [{
            'file_name': '2012_12.jpg',
            'height': 256,
            'width': 256,
            'id': 1
        }]
    }
    self.assertEqual(ann_json_dict, expected_ann_json_dict)


if __name__ == '__main__':
  # Limit absl log output to WARNING and above while the tests run.
  logging.set_verbosity(logging.WARNING)
  tf.test.main()
